package com.lc.projects.unit.thread.part07.test1;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

// Implements the ForkJoinWorkerThreadFactory interface so the Fork/Join framework creates custom worker threads.
public class Test7 {
	/**
	 * Demo entry point: sums the integers 0..999 on a 4-thread ForkJoinPool whose
	 * workers are created by {@link MyWorkerThreadFactory}, then prints the result.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		MyWorkerThreadFactory factory = new MyWorkerThreadFactory();
		// parallelism 4, custom worker factory, no UncaughtExceptionHandler, asyncMode=false
		ForkJoinPool pool = new ForkJoinPool(4, factory, null, false);
		int[] array = new int[1000];
		for (int i = 0; i < array.length; i++) {
			array[i] = i;
		}

		MyRecursiveTask task = new MyRecursiveTask(array, 0, array.length);
		Integer result;
		try {
			pool.execute(task);
			// join() waits for completion and rethrows any task failure unchecked,
			// so a follow-up get() with its checked exceptions adds nothing.
			result = task.join();
		} finally {
			// Shut the pool down even if the task failed.
			pool.shutdown();
			try {
				pool.awaitTermination(1, TimeUnit.HOURS);
			} catch (InterruptedException e) {
				e.printStackTrace();
				// Preserve the interrupt for anything further up the stack.
				Thread.currentThread().interrupt();
			}
		}

		System.out.println("Main:Result:" + result);
		System.out.println("Main:End.");
	}
}

/**
 * ForkJoin worker thread that counts the tasks recorded on it via {@link #addTask()}
 * and prints that count when the worker terminates.
 */
class MyWorkerThread extends ForkJoinWorkerThread {

	/**
	 * Number of tasks recorded on this worker. Tied to this thread object (not to
	 * whichever thread calls {@link #addTask()}). Written only by the worker
	 * itself; volatile so other threads can read the latest value.
	 */
	private volatile int taskCounter;

	protected MyWorkerThread(ForkJoinPool pool) {
		super(pool);
	}

	/**
	 * Resets the task counter before this worker processes any tasks.
	 */
	@Override
	protected void onStart() {
		super.onStart();
		taskCounter = 0;
		System.out.println("MyWorkerThread: Initializing task counter: " + getId());
	}

	/**
	 * Prints this worker's id and the number of tasks it recorded.
	 *
	 * @param exception the exception that caused termination, or null if none
	 */
	@Override
	protected void onTermination(Throwable exception) {
		System.out.println("MyWorkerThread: " + getId() + " " + taskCounter);
		super.onTermination(exception);
	}

	/**
	 * Records one more executed task on this worker.
	 * Must be called from this worker thread itself (single writer), e.g. from a
	 * task's compute() while it runs on this worker.
	 */
	public void addTask() {
		// Non-atomic increment is safe because only this worker ever writes the field.
		taskCounter++;
	}

	/**
	 * @return the number of tasks recorded via {@link #addTask()} so far
	 */
	public int getTaskCount() {
		return taskCounter;
	}
}


/**
 * Worker factory that makes a ForkJoinPool populate itself with
 * {@link MyWorkerThread} instances instead of default worker threads.
 */
class MyWorkerThreadFactory implements ForkJoinWorkerThreadFactory {

	/**
	 * Builds a new worker attached to the given pool.
	 *
	 * @param pool the pool requesting the worker
	 * @return a freshly constructed {@link MyWorkerThread}
	 */
	@Override
	public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
		final ForkJoinWorkerThread worker = new MyWorkerThread(pool);
		return worker;
	}
}

/**
 * Sums {@code array[start, end)} by Fork/Join divide and conquer. Each call to
 * {@link #compute()} is recorded on the executing {@link MyWorkerThread} when the
 * task runs on one.
 */
class MyRecursiveTask extends RecursiveTask<Integer> {

	private static final long serialVersionUID = 1L;

	/** Ranges no longer than this are summed directly instead of being split. */
	private static final int THRESHOLD = 100;

	private final int[] array;
	private final int start, end;

	/**
	 * @param array the data to sum (shared, not copied)
	 * @param start first index to include
	 * @param end   index one past the last element to include
	 */
	public MyRecursiveTask(int[] array, int start, int end) {
		super();
		this.array = array;
		this.start = start;
		this.end = end;
	}

	/**
	 * Sums the range directly when it is small, otherwise splits it in half,
	 * runs both halves in parallel and combines them.
	 *
	 * @return the sum of {@code array[start, end)} (int arithmetic, may overflow)
	 */
	@Override
	protected Integer compute() {
		Thread current = Thread.currentThread();
		// Count only when a MyWorkerThread runs us; an unconditional cast would throw
		// ClassCastException if the task were invoked from any other thread.
		if (current instanceof MyWorkerThread) {
			((MyWorkerThread) current).addTask();
		}

		if (end - start <= THRESHOLD) {
			// Primitive accumulator avoids boxing on every iteration.
			int sum = 0;
			for (int i = start; i < end; i++) {
				sum += array[i];
			}
			return sum;
		}

		int mid = (start + end) >>> 1;
		MyRecursiveTask left = new MyRecursiveTask(array, start, mid);
		MyRecursiveTask right = new MyRecursiveTask(array, mid, end);
		invokeAll(left, right);
		return addResults(left, right);
	}

	/**
	 * Combines the results of two completed subtasks, then pauses briefly to
	 * simulate extra work for the demo.
	 * A subtask failure propagates (unchecked) via join() instead of being
	 * silently replaced by 0.
	 */
	private Integer addResults(MyRecursiveTask task1, MyRecursiveTask task2) {
		int value = task1.join() + task2.join();

		try {
			TimeUnit.MILLISECONDS.sleep(10);
		} catch (InterruptedException e) {
			// Keep the interrupt visible to the pool / caller.
			Thread.currentThread().interrupt();
		}

		return value;
	}
}